{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pinned dependencies for this notebook.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q --upgrade pip\n",
    "%pip install -q wrapt --upgrade --ignore-installed\n",
    "%pip install -q tensorflow==2.1.0\n",
    "%pip install -q transformers==2.8.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "# SageMaker session, its default S3 bucket, and the execution role used by later cells.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# Region of the default boto3 session; reused for the client below and for console links.\n",
    "region = boto3.Session().region_name\n",
    "sm = boto3.Session().client('sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve the Training Job Name from the Previous Notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore `training_job_name` from IPython's %store database.\n",
    "# NOTE(review): presumably saved with `%store training_job_name` in an earlier notebook - confirm.\n",
    "%store -r training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the restored value; raises NameError here if `training_job_name` was never defined.\n",
    "print(training_job_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Labelled echo of the training job whose model artifact is compiled below.\n",
    "print(f'Previous Training Job Name: {training_job_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# SageMaker limits CompilationJobName to 63 characters. Trim the training-job prefix so that\n",
    "# '<prefix>-<timestamp>' always fits, and drop any trailing '-' left by the cut so the name\n",
    "# does not contain '--'. Names that already fit are unchanged.\n",
    "max_job_name_length = 63\n",
    "timestamp = '{}'.format(int(time.time()))\n",
    "job_name_prefix = training_job_name[:max_job_name_length - len(timestamp) - 1].rstrip('-')\n",
    "\n",
    "compilation_job_name = '{}-{}'.format(job_name_prefix, timestamp)\n",
    "\n",
    "sm_client = boto3.client('sagemaker')\n",
    "\n",
    "# JSON mapping of each model input name to its shape, passed as DataInputConfig.\n",
    "data_shape = '{\"input_ids\":[1,128],\"input_mask\":[1,128],\"segment_ids\":[1,128]}'\n",
    "target_device = 'ml_c5'\n",
    "framework = 'TENSORFLOW' # TFLITE\n",
    "#framework_version = '2.1.0'\n",
    "\n",
    "# Input artifact of the training job and the S3 prefix that receives the compiled model.\n",
    "model_path = 's3://{}/{}/output/model.tar.gz'.format(bucket, training_job_name)\n",
    "compiled_model_path = 's3://{}/{}/compiled-output/'.format(bucket, training_job_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TODO:  Work around the following error, if possible:\n",
    "```\n",
    "Incompatible Tensorflow model: The following operators are not implemented: {'StatefulPartitionedCall'}\n",
    "```\n",
    "\n",
    "Different forms of this error show up for `ml_c5`, `ml_inf1`, and `ml_p3` for our BERT model.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the request pieces first, then submit the compilation job.\n",
    "input_config = {\n",
    "    'S3Uri': model_path,\n",
    "    'DataInputConfig': data_shape,\n",
    "    'Framework': framework,\n",
    "}\n",
    "output_config = {\n",
    "    'S3OutputLocation': compiled_model_path,\n",
    "    'TargetDevice': target_device,\n",
    "}\n",
    "\n",
    "response = sm_client.create_compilation_job(\n",
    "    CompilationJobName=compilation_job_name,\n",
    "    RoleArn=role,\n",
    "    InputConfig=input_config,\n",
    "    OutputConfig=output_config,\n",
    "    StoppingCondition={'MaxRuntimeInSeconds': 300},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `display` and `HTML` are public in IPython.display; importing them from\n",
    "# IPython.core.display is deprecated.\n",
    "from IPython.display import display, HTML\n",
    "\n",
    "# Render a clickable link to this compilation job in the SageMaker console.\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/compilation-jobs/{}\">Compilation Job</a></b>'.format(region, compilation_job_name)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Poll every 10 sec until the job reaches a terminal state.\n",
    "# STOPPED is terminal too (e.g. the job hit MaxRuntimeInSeconds or was stopped manually);\n",
    "# without handling it the loop would never exit.\n",
    "while True:\n",
    "    response = sm_client.describe_compilation_job(CompilationJobName=compilation_job_name)\n",
    "    status = response['CompilationJobStatus']\n",
    "    if status == 'COMPLETED':\n",
    "        break\n",
    "    if status in ('FAILED', 'STOPPED'):\n",
    "        # Surface the service-provided reason so the failure can be diagnosed from the notebook.\n",
    "        reason = response.get('FailureReason') or 'no failure reason reported'\n",
    "        raise RuntimeError('Compilation {}: {}'.format(status.lower(), reason))\n",
    "    print('Compiling ...')\n",
    "    time.sleep(10)\n",
    "print('Done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the output prefix with the exact S3 URI of the compiled artifact,\n",
    "# taken from the last describe_compilation_job response.\n",
    "model_artifacts = response['ModelArtifacts']\n",
    "compiled_model_path = model_artifacts['S3ModelArtifacts']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TODO:  TFLite is currently throwing an error related to GPUs, CUDA, and TensorRT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the trained model archive from S3 into the working directory.\n",
    "!aws s3 cp s3://{bucket}/{training_job_name}/output/model.tar.gz ./model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the gzip-compressed archive, listing each extracted file.\n",
    "!tar --extract --gzip --verbose --file model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Directory of the SavedModel unpacked from model.tar.gz.\n",
    "saved_model_dir = './tensorflow/saved_model/0/'\n",
    "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `post_training_quantize` is a TF1 converter property; on the TF2 converter assigning it only\n",
    "# creates an unused attribute. An empty `optimizations` list is the supported way to request\n",
    "# no post-training quantization.\n",
    "converter.optimizations = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allow both TFLite builtin ops and selected TensorFlow ops in the converted model.\n",
    "supported_ops = [\n",
    "    tf.lite.OpsSet.TFLITE_BUILTINS,\n",
    "    tf.lite.OpsSet.SELECT_TF_OPS,\n",
    "]\n",
    "converter.target_spec.supported_ops = supported_ops\n",
    "\n",
    "tflite_model = converter.convert()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "tflite_model_path = './tflite_optimized_model.tflite'\n",
    "\n",
    "# Use a context manager so the file is flushed and closed before it is read back\n",
    "# from `tflite_model_path`; write() returns the number of bytes written.\n",
    "with open(tflite_model_path, \"wb\") as tflite_file:\n",
    "    model_size = tflite_file.write(tflite_model)\n",
    "\n",
    "print('\\nModel size reduced to %s bytes' % model_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following throws:\n",
    "\n",
    "```\n",
    "RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you apply/link the Flex delegate before inference.Node number 170 (FlexErf) failed to prepare.\n",
    "```\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "# Build an interpreter for the converted model file, then allocate its tensors\n",
    "# so inputs can be set and outputs read.\n",
    "interpreter = tf.lite.Interpreter(tflite_model_path)\n",
    "interpreter.allocate_tensors()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get input and output tensors.\n",
    "input_details = interpreter.get_input_details()\n",
    "print('Input Tensor Details: %s' % input_details)\n",
    "\n",
    "output_details = interpreter.get_output_details()\n",
    "print('Output Tensor Details: %s' % output_details)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test model on random input data.\n",
    "input_shape = input_details[0]['shape']\n",
    "input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)\n",
    "print('Input: %s' % input_data)\n",
    "interpreter.set_tensor(input_details[0]['index'], input_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# Run one inference pass on the tensors set above; %%time reports how long the cell takes.\n",
    "interpreter.invoke()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the first output tensor produced by invoke() and print it.\n",
    "output_index = output_details[0]['index']\n",
    "output_data = interpreter.get_tensor(output_index)\n",
    "print('Prediction: {}'.format(output_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
